import hydra
import torch
import os
from omegaconf import DictConfig, OmegaConf

from Initializers.data import initialize_data
from Initializers.init_utils import init_logistics, init_loading
from Initializers.initialize_models import initialize_models
from Initializers.wandb_logging_policy import initialize_wandb
from Policy.gcrl_trainer import GCTrainer


@hydra.main(version_base=None, config_path="./configs", config_name="config")
def rollout_model(config: DictConfig):
    """Roll out a model for evaluation and visualization generation.

    Builds the environments, models, buffer and collectors from the Hydra
    config, calls ``init_loading`` to load a buffer if configured, and then
    runs the train collector for ``config.train.init_random_step`` steps. It
    saves frames under ``config.save.save_frames/<env_index>`` for each
    training environment.

    Args:
        config: Hydra config read from ``./configs/config.yaml`` (plus any
            command-line overrides); supplied by the ``hydra.main`` decorator.
    """
    # Recreate the config from its YAML dump. NOTE(review): presumably this
    # detaches it from Hydra's composed object so it can be modified below
    # (e.g. ``config.device``) — confirm.
    config = OmegaConf.structured(OmegaConf.to_yaml(config))

    # Create the environments and assign the number of variables and the
    # observation shape. No wandb run is attached for rollouts.
    single_env, train_env, test_env, norm, logger, config = init_logistics(config, wdb_run=None)

    # Use the configured CUDA device when available, else CPU
    # (similar to code in tianshou.examples).
    config.device = torch.device(f"cuda:{config.cuda_id}" if torch.cuda.is_available() else "cpu")

    # Create the models; this is the most involved part of initialization.
    dynamics, graph_encoding, reward, policy = initialize_models(config, single_env, norm, None)

    # Initialize a vectorized buffer (for the number of environments) and the
    # train and test collectors (handling the action-observation-reward loop).
    train_collector, test_collector, buffer = initialize_data(
        config, policy, dynamics, single_env, train_env, test_env, norm
    )

    # Load a buffer from disk if necessary. The models are passed in too;
    # NOTE(review): presumably so their saved weights can be restored — confirm.
    buffer = init_loading(config, dynamics, graph_encoding, policy, buffer)

    # The collector must use the (possibly reloaded) buffer.
    train_collector.buffer = buffer

    # Create one frame directory per training environment. exist_ok=True makes
    # an already-existing directory (e.g. from an earlier run) harmless, so
    # the remaining directories are still created instead of the whole loop
    # being aborted by the first FileExistsError.
    frame_root = config.save.save_frames
    for env_index in range(len(train_env)):
        os.makedirs(os.path.join(frame_root, str(env_index)), exist_ok=True)

    # Step count comes from init_random_step, but random=False means the
    # policy chooses the actions. Frames are written to frame_root rather
    # than shown on screen.
    train_collector.collect(
        n_step=config.train.init_random_step,
        random=False,
        show_frame=False,
        save_frame_dir=frame_root,
    )

if __name__ == "__main__":
    # hydra.main parses the command line and injects ``config``, so the
    # entry point is invoked without arguments.
    rollout_model()
